{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make sure to run this notebook from inside the electra repo folder after cloning it:\n",
    "\n",
    "```bash\n",
    "git clone https://github.com/google-research/electra.git\n",
    "cd electra\n",
    "jupyter notebook\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-07-19 19:22:35--  https://storage.googleapis.com/electra-data/electra_base.zip\n",
      "Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.130.128, 172.217.194.128, 2404:6800:4003:c00::80, ...\n",
      "Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.130.128|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 885890161 (845M) [application/zip]\n",
      "Saving to: ‘electra_base.zip’\n",
      "\n",
      "electra_base.zip    100%[===================>] 844.85M  22.2MB/s    in 48s     \n",
      "\n",
      "2020-07-19 19:23:24 (17.5 MB/s) - ‘electra_base.zip’ saved [885890161/885890161]\n",
      "\n",
      "Archive:  electra_base.zip\n",
      "   creating: electra_base/\n",
      "  inflating: electra_base/electra_base.meta  \n",
      "  inflating: electra_base/electra_base.index  \n",
      "  inflating: electra_base/checkpoint  \n",
      "  inflating: electra_base/vocab.txt  \n",
      "  inflating: electra_base/electra_base.data-00000-of-00001  \n",
      "checkpoint\t\t\t  electra_base.index  vocab.txt\n",
      "electra_base.data-00000-of-00001  electra_base.meta\n"
     ]
    }
   ],
   "source": [
    "# Fetch and unpack the pre-trained ELECTRA base checkpoint.\n",
    "# -nc skips the download when the archive already exists and -n never\n",
    "# overwrites extracted files, so re-running this cell is a cheap no-op\n",
    "# instead of a second 845M download plus an overwrite prompt.\n",
    "!wget -nc https://storage.googleapis.com/electra-data/electra_base.zip\n",
    "!unzip -n electra_base.zip\n",
    "!ls electra_base/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'vocab_size': 30522,\n",
       " 'hidden_size': 768,\n",
       " 'num_hidden_layers': 12,\n",
       " 'num_attention_heads': 12,\n",
       " 'hidden_act': 'gelu',\n",
       " 'intermediate_size': 3072,\n",
       " 'hidden_dropout_prob': 0.1,\n",
       " 'attention_probs_dropout_prob': 0.1,\n",
       " 'max_position_embeddings': 512,\n",
       " 'type_vocab_size': 2,\n",
       " 'initializer_range': 0.02}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import configure_finetuning\n",
    "from util import training_utils\n",
    "\n",
    "# Overrides forwarded to FinetuningConfig as keyword arguments.\n",
    "hparams = dict(model_size='base', vocab_size=30522)\n",
    "config = configure_finetuning.FinetuningConfig(\n",
    "    'electra-base', './electra_base/', **hparams\n",
    ")\n",
    "bert_config = training_utils.get_bert_config(config)\n",
    "\n",
    "# Display the attributes of the resulting config object as the cell output.\n",
    "vars(bert_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from model import modeling\n",
    "from model import optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_SEQ_LENGTH = 120"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import tokenization\n",
    "\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "        vocab_file='electra_base/vocab.txt',\n",
    "        do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['train_X', 'train_Y', 'test_X', 'test_Y'])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "# Name the encoding explicitly so that decoding the JSON file does not\n",
    "# depend on the locale of the machine running the notebook.\n",
    "with open('../text.json', encoding='utf-8') as fopen:\n",
    "    data = json.load(fopen)\n",
    "\n",
    "# Show which splits the file provides.\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def _truncate_seq_pair(tokens_a, tokens_b, max_length):\n",
    "    while True:\n",
    "        total_length = len(tokens_a) + len(tokens_b)\n",
    "        if total_length <= max_length:\n",
    "            break\n",
    "        if len(tokens_a) > len(tokens_b):\n",
    "              tokens_a.pop()\n",
    "        else:\n",
    "              tokens_b.pop()\n",
    "\n",
    "def get_data(left, right):\n",
    "    \"\"\"Encode sentence pairs as padded inputs for the model.\n",
    "\n",
    "    Each pair is tokenized with the notebook-level ``tokenizer``, truncated\n",
    "    so that it fits in ``MAX_SEQ_LENGTH`` together with three special\n",
    "    tokens, laid out as ``[CLS] left [SEP] right [SEP]`` and right-padded\n",
    "    with zeros up to ``MAX_SEQ_LENGTH``.\n",
    "\n",
    "    Args:\n",
    "        left: sequence of first sentences.\n",
    "        right: sequence of second sentences, aligned with ``left``.\n",
    "\n",
    "    Returns:\n",
    "        ``(input_ids, input_masks, segment_ids)``, three lists with one\n",
    "        entry per pair. Masks are 1 on real tokens and 0 on padding.\n",
    "        Segment ids are 0 for ``[CLS]``, the left tokens and the first\n",
    "        ``[SEP]``, 1 for the right tokens and the last ``[SEP]``, and 0 on\n",
    "        padding.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``left`` and ``right`` differ in length.\n",
    "    \"\"\"\n",
    "    # Fail early instead of dropping trailing items of a longer `right` or\n",
    "    # hitting an IndexError halfway through a shorter one.\n",
    "    if len(left) != len(right):\n",
    "        raise ValueError(\n",
    "            'left and right must have the same length, got %d and %d'\n",
    "            % (len(left), len(right))\n",
    "        )\n",
    "\n",
    "    input_ids, input_masks, segment_ids = [], [], []\n",
    "    for text_a, text_b in tqdm(zip(left, right), total=len(left)):\n",
    "        tokens_a = tokenizer.tokenize(text_a)\n",
    "        tokens_b = tokenizer.tokenize(text_b)\n",
    "        # Reserve three positions for [CLS] and the two [SEP] markers.\n",
    "        _truncate_seq_pair(tokens_a, tokens_b, MAX_SEQ_LENGTH - 3)\n",
    "\n",
    "        tokens = [\"[CLS]\"] + tokens_a + [\"[SEP]\"] + tokens_b + [\"[SEP]\"]\n",
    "        # Segment 0 covers [CLS] + left + first [SEP]; segment 1 the rest.\n",
    "        segment_id = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)\n",
    "        input_id = tokenizer.convert_tokens_to_ids(tokens)\n",
    "        input_mask = [1] * len(input_id)\n",
    "\n",
    "        # Zero padding shared by ids, mask and segments (empty when the\n",
    "        # sequence already reaches MAX_SEQ_LENGTH).\n",
    "        padding = [0] * (MAX_SEQ_LENGTH - len(input_id))\n",
    "        input_ids.append(input_id + padding)\n",
    "        input_masks.append(input_mask + padding)\n",
    "        segment_ids.append(segment_id + padding)\n",
    "    return input_ids, input_masks, segment_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 261802/261802 [03:03<00:00, 1425.14it/s]\n"
     ]
    }
   ],
   "source": [
    "# Each training example stores both sentences in one string joined by ' <> '.\n",
    "left, right = [], []\n",
    "for pair_text in data['train_X']:\n",
    "    first, second = pair_text.split(' <> ')\n",
    "    left.append(first)\n",
    "    right.append(second)\n",
    "\n",
    "train_ids, train_masks, segment_train = get_data(left, right)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13395/13395 [00:09<00:00, 1446.81it/s]\n"
     ]
    }
   ],
   "source": [
    "# Each test example stores both sentences in one string joined by ' <> '.\n",
    "left, right = [], []\n",
    "for pair_text in data['test_X']:\n",
    "    first, second = pair_text.split(' <> ')\n",
    "    left.append(first)\n",
    "    right.append(second)\n",
    "\n",
    "test_ids, test_masks, segment_test = get_data(left, right)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finetune import task_builder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'chunk'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tasks = task_builder.get_tasks(config)\n",
    "tasks[0].name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 60\n",
    "epoch = 10\n",
    "num_train_steps = int(len(train_ids) / batch_size * epoch)\n",
    "\n",
    "class Model:\n",
    "    \"\"\"Graph for a sequence-pair classifier on top of ``modeling.BertModel``.\n",
    "\n",
    "    Building an instance creates the input placeholders, the encoder, a\n",
    "    dense classification layer, the loss, the training op and an accuracy\n",
    "    metric in the current default graph. Besides its argument it reads the\n",
    "    notebook-level names ``bert_config``, ``config`` and ``num_train_steps``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "    ):\n",
    "        \"\"\"Build the graph.\n",
    "\n",
    "        Args:\n",
    "            dimension_output: number of units of the final dense layer,\n",
    "                i.e. the number of classes scored by ``self.logits``.\n",
    "        \"\"\"\n",
    "        # Rank-2 int32 inputs with both dimensions left unspecified.\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.int32, [None, None])\n",
    "        # Rank-1 int32 class ids, one per example.\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "        \n",
    "        # NOTE(review): is_training=False is passed although this same graph\n",
    "        # is used for the training op below - presumably to keep the encoder\n",
    "        # deterministic for both training and evaluation; confirm intended.\n",
    "        model = modeling.BertModel(\n",
    "            bert_config=bert_config,\n",
    "            is_training=False,\n",
    "            input_ids=self.X,\n",
    "            input_mask=self.input_masks,\n",
    "            token_type_ids=self.segment_ids,\n",
    "            use_one_hot_embeddings=False)\n",
    "        \n",
    "        output_layer = model.get_pooled_output()\n",
    "        \n",
    "        # New classification layer, created under its own variable scope.\n",
    "        with tf.variable_scope(\"task_specific/classify\"):\n",
    "            self.logits = tf.layers.dense(output_layer, dimension_output)\n",
    "        \n",
    "        # Mean sparse softmax cross-entropy over the batch.\n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits = self.logits, labels = self.Y\n",
    "            )\n",
    "        )\n",
    "        # Training op from the repo's optimization module, configured from\n",
    "        # the notebook-level `config` and `num_train_steps`.\n",
    "        self.optimizer = optimization.create_optimizer(\n",
    "              self.cost, config.learning_rate, num_train_steps,\n",
    "              weight_decay_rate=config.weight_decay_rate,\n",
    "              use_tpu=config.use_tpu,\n",
    "              warmup_proportion=config.warmup_proportion,\n",
    "              layerwise_lr_decay_power=config.layerwise_lr_decay,\n",
    "              n_transformer_layers=bert_config.num_hidden_layers\n",
    "          )\n",
    "        \n",
    "        # Fraction of examples whose arg-max logit equals the label.\n",
    "        correct_pred = tf.equal(\n",
    "            tf.argmax(self.logits, 1, output_type = tf.int32), self.Y\n",
    "        )\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/text-similarity/electra/model/modeling.py:698: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/clip_ops.py:301: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "dimension_output = 2\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    ")\n",
    "\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from electra_base/electra_base\n"
     ]
    }
   ],
   "source": [
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'electra')\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, 'electra_base/electra_base')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The position of a label in this list is used as its integer class id.\n",
    "labels = ['contradiction', 'entailment']\n",
    "\n",
    "# list.index raises ValueError for any label that is not listed above.\n",
    "train_Y = list(map(labels.index, data['train_Y']))\n",
    "test_Y = list(map(labels.index, data['test_Y']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop:   0%|          | 0/4364 [00:30<?, ?it/s].12it/s, accuracy=0.667, cost=0.656]\n",
      "train minibatch loop:   1%|          | 27/4364 [00:18<50:02,  1.44it/s, accuracy=0.55, cost=0.694]\n",
      "train minibatch loop: 100%|██████████| 4364/4364 [34:23<00:00,  2.11it/s, accuracy=0.955, cost=0.233] \n",
      "test minibatch loop: 100%|██████████| 224/224 [00:34<00:00,  6.49it/s, accuracy=1, cost=0.00656]   \n",
      "train minibatch loop:   0%|          | 0/4364 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.963170\n",
      "time taken: 2097.9571702480316\n",
      "epoch: 0, training loss: 0.158017, training acc: 0.938330, valid loss: 0.099504, valid acc: 0.963170\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 4364/4364 [34:21<00:00,  2.12it/s, accuracy=1, cost=0.0083]    \n",
      "test minibatch loop: 100%|██████████| 224/224 [00:34<00:00,  6.55it/s, accuracy=1, cost=0.00116]   "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 2095.7804176807404\n",
      "epoch: 1, training loss: 0.070698, training acc: 0.976784, valid loss: 0.148000, valid acc: 0.962723\n",
      "\n",
      "break epoch:2\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "import numpy as np\n",
    "\n",
    "tqdm._instances.clear()\n",
    "# EARLY_STOPPING: number of consecutive epochs without a new best validation\n",
    "# accuracy after which training stops; CURRENT_CHECKPOINT counts those epochs\n",
    "# and CURRENT_ACC holds the best validation accuracy seen so far.\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 1, 0, 0, 0\n",
    "\n",
    "while True:\n",
    "    lasttime = time.time()\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "\n",
    "    train_acc, train_loss, test_acc, test_loss = [], [], [], []\n",
    "    # Number of examples in every batch, used to weight the epoch metrics.\n",
    "    train_sizes, test_sizes = [], []\n",
    "\n",
    "    # Training pass: one optimizer step per mini-batch.\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_ids), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_ids))\n",
    "        batch_x = train_ids[i: index]\n",
    "        batch_masks = train_masks[i: index]\n",
    "        batch_segment = segment_train[i: index]\n",
    "        batch_y = train_Y[i: index]\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks\n",
    "            },\n",
    "        )\n",
    "        # Stop immediately if the loss has diverged.\n",
    "        assert not np.isnan(cost)\n",
    "        train_loss.append(cost)\n",
    "        train_acc.append(acc)\n",
    "        train_sizes.append(len(batch_y))\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    # Evaluation pass: metrics only, the optimizer op is not run.\n",
    "    pbar = tqdm(range(0, len(test_ids), batch_size), desc = 'test minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_ids))\n",
    "        batch_x = test_ids[i: index]\n",
    "        batch_masks = test_masks[i: index]\n",
    "        batch_segment = segment_test[i: index]\n",
    "        batch_y = test_Y[i: index]\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks\n",
    "            },\n",
    "        )\n",
    "        test_loss.append(cost)\n",
    "        test_acc.append(acc)\n",
    "        test_sizes.append(len(batch_y))\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    # model.cost and model.accuracy are per-batch means and the last batch is\n",
    "    # usually smaller than batch_size, so a plain mean over batches would\n",
    "    # over-weight its examples. Weighting by batch size gives the exact\n",
    "    # per-example average for the epoch.\n",
    "    train_loss = np.average(train_loss, weights = train_sizes)\n",
    "    train_acc = np.average(train_acc, weights = train_sizes)\n",
    "    test_loss = np.average(test_loss, weights = test_sizes)\n",
    "    test_acc = np.average(test_acc, weights = test_sizes)\n",
    "\n",
    "    # Early-stopping bookkeeping: reset the counter on a new best accuracy.\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    EPOCH += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
